
import os
import torch
from tqdm import tqdm
from PIL import Image
from torchvision import transforms

# Image preprocessing pipeline applied to every input image:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with fixed mean/std values.
# NOTE(review): the mean/std are the commonly used ImageNet channel
# statistics, presumably chosen to match an ImageNet-pretrained feature
# extractor — confirm against the model actually used.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# Multiplier applied to the value returned by ``compute_mmd`` (presumably so
# that small MMD values are easier to read/log). ``approximate_mmd`` does not
# apply this factor.
_SCALE = 1000


def compute_mmd(fx, fy, sigma=10):
    """Return a scaled, biased RBF-kernel MMD^2 estimate between two sample sets.

    Args:
        fx: Tensor whose first dimension indexes samples; all remaining
            dimensions are flattened into one feature vector per sample.
        fy: Tensor laid out like ``fx``; after flattening it must have the
            same number of features per sample as ``fx``.
        sigma: RBF bandwidth. The kernel is
            ``exp(-||a - b||^2 / (2 * sigma**2))``.

    Returns:
        0-dim tensor equal to
        ``(mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)) * _SCALE``.
        The means run over the full kernel matrices, diagonal included, so
        this is the biased (V-statistic) estimator.
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. tensors
    # produced by .permute()/.transpose() or strided slicing; for contiguous
    # inputs it returns the same view as before.
    fx = fx.reshape(fx.size(0), -1)
    fy = fy.reshape(fy.size(0), -1)

    # RBF kernel parameter: exp(-gamma * d^2) with gamma = 1 / (2 sigma^2).
    gamma = 1 / (2 * sigma ** 2)

    k_xx = compute_kernel(fx, fx, gamma)
    k_yy = compute_kernel(fy, fy, gamma)
    k_xy = compute_kernel(fx, fy, gamma)

    mmd = k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
    return mmd * _SCALE
def compute_kernel(x, y, gamma):
    """Return the RBF (Gaussian) kernel matrix between the rows of two matrices.

    Entry ``(i, j)`` is ``exp(-gamma * ||x[i] - y[j]||^2)``.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d).
        gamma: Kernel coefficient multiplying the squared distance.

    Returns:
        Tensor of shape (n, m).
    """
    row_sq = x.pow(2).sum(dim=1, keepdim=True)   # (n, 1)
    col_sq = y.pow(2).sum(dim=1).unsqueeze(0)    # (1, m)
    cross = x.mm(y.t())                          # (n, m)

    # Expand ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. Floating-point rounding
    # can make this slightly negative, so floor it at a tiny positive value.
    sq_dist = (row_sq + col_sq - 2.0 * cross).clamp_min(1e-12)

    return torch.exp(-gamma * sq_dist)

def approximate_mmd(fx, fy, sigma=10, D=2000, generator=None):
    """Return an RBF-kernel MMD^2 estimate using at most ``D`` samples per set.

    Each input with more than ``D`` rows is randomly subsampled (without
    replacement) to ``D`` rows before the full pairwise kernel matrices are
    built, which bounds their size to at most D x D. Unlike ``compute_mmd``,
    the result is NOT multiplied by ``_SCALE``.

    Args:
        fx: Tensor whose first dimension indexes samples; all remaining
            dimensions are flattened into one feature vector per sample.
        fy: Tensor laid out like ``fx``; after flattening it must have the
            same number of features per sample as ``fx``.
        sigma: RBF bandwidth. The kernel is
            ``exp(-||a - b||^2 / (2 * sigma**2))``.
        D: Maximum number of samples kept from each of ``fx`` and ``fy``.
        generator: Optional CPU ``torch.Generator`` driving the subsampling
            permutation; pass a seeded generator for reproducible results.
            ``None`` (default) uses the global RNG, as before.

    Returns:
        0-dim tensor ``mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)`` (means include
        the diagonal, i.e. the biased estimator).
    """
    # reshape (unlike view) also accepts non-contiguous inputs.
    fx = fx.reshape(fx.size(0), -1)
    fy = fy.reshape(fy.size(0), -1)

    # Random subsample without replacement. The permutation is drawn on the CPU
    # and the index is moved to the data's device before indexing.
    if fx.size(0) > D:
        keep = torch.randperm(fx.size(0), generator=generator)[:D]
        fx = fx[keep.to(fx.device)]
    if fy.size(0) > D:
        keep = torch.randperm(fy.size(0), generator=generator)[:D]
        fy = fy[keep.to(fy.device)]

    # RBF kernel parameter: exp(-gamma * d^2) with gamma = 1 / (2 sigma^2).
    gamma = 1 / (2 * sigma ** 2)

    k_xx = compute_kernel(fx, fx, gamma)
    k_yy = compute_kernel(fy, fy, gamma)
    k_xy = compute_kernel(fx, fy, gamma)

    return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

def save_images(images, base_output_dir, gender_disease, batch_idx):
    """Save each image in ``images`` as a PNG file inside ``base_output_dir``.

    Files are named ``{gender_disease}_{batch_idx}_{idx}.png``, where ``idx``
    is the image's position in ``images``. The directory, including any
    missing parents, is created if it does not exist.

    Args:
        images: Iterable of objects exposing ``.save(path)`` (presumably PIL
            images — confirm against callers).
        base_output_dir: Directory the files are written to.
        gender_disease: Prefix used in every file name.
        batch_idx: Batch identifier embedded in every file name.
    """
    # exist_ok=True avoids the check-then-create race of
    # os.path.exists() + os.makedirs(), which raises FileExistsError if another
    # process creates the directory in between.
    os.makedirs(base_output_dir, exist_ok=True)
    for idx, img in enumerate(images):
        filename = f'{gender_disease}_{batch_idx}_{idx}.png'
        img.save(os.path.join(base_output_dir, filename))

def compute_reference_features(image_paths, feature_extractor, device):
    """Extract features for every image path, one image at a time.

    Each image is loaded, converted to RGB, passed through the module-level
    ``preprocess`` pipeline, moved to ``device`` and fed to
    ``feature_extractor`` as a batch of one with gradients disabled. Each
    output is moved to the CPU, and all outputs are concatenated along dim 0.

    Args:
        image_paths: Iterable of image file paths (a progress bar is shown).
        feature_extractor: Callable taking a batched image tensor.
        device: Device the preprocessed image tensor is moved to.

    Returns:
        CPU tensor of concatenated features, or an empty tensor
        (``torch.tensor([])``) when ``image_paths`` is empty.
    """
    features = []
    with torch.no_grad():
        for img_path in tqdm(image_paths):
            # Close the file handle right away. convert() returns a new image
            # that does not depend on the opened file, so this avoids leaking
            # one descriptor per path.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
            img_tensor = preprocess(img).to(device)
            feature = feature_extractor(img_tensor.unsqueeze(0))
            features.append(feature.cpu())
    if not features:
        return torch.tensor([])
    return torch.cat(features, dim=0)
